(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)¶

Author: Driss Guessous

Summary¶

In this tutorial, we want to highlight a new torch.nn.functional function that can be helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.
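As a small illustration of that last point, the sketch below calls torch.nn.MultiheadAttention on a self-attention input. The sizes are arbitrary illustrative choices, not values from this tutorial. Passing need_weights=False is what the module's documentation recommends for using the optimized scaled dot product attention path, since the full attention matrix is then not returned.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only (assumptions, not values from the tutorial).
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True).eval()
x = torch.randn(2, 5, 16)  # (batch, sequence, embedding)

with torch.no_grad():
    # Self-attention: query, key and value are the same tensor.
    # need_weights=False means no attention weights are returned.
    out, weights = mha(x, x, x, need_weights=False)

print(out.shape)  # torch.Size([2, 5, 16])
print(weights)    # None
```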

Overview¶

At a high level, this PyTorch function calculates the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention Is All You Need. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.
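To make the definition concrete, here is a minimal unfused reference sketch of the same computation, softmax(Q Kᵀ / sqrt(d)) V, checked against the built-in function. The helper name naive_sdpa and the tensor shapes are illustrative assumptions, not part of the tutorial. The sketch ignores dropout and arbitrary attention masks and handles only the optional causal case.

```python
import math

import torch
import torch.nn.functional as F


def naive_sdpa(query, key, value, is_causal=False):
    """Unfused reference: softmax(Q @ K^T / sqrt(d)) @ V, without dropout."""
    d = query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d)
    if is_causal:
        L, S = query.size(-2), key.size(-2)
        # Position i may only attend to positions j <= i.
        keep = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        scores = scores.masked_fill(~keep, float("-inf"))
    return torch.softmax(scores, dim=-1) @ value


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 8) for _ in range(3))

for causal in (False, True):
    expected = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    matches = torch.allclose(naive_sdpa(q, k, v, causal), expected, atol=1e-5)
    print(f"is_causal={causal}: matches built-in = {matches}")
```

This version materializes the full attention score matrix, which is the memory cost that fused kernels are designed to avoid.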

Fused implementations¶

For CUDA tensor inputs, the function will dispatch into one of the following implementations:

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Memory-Efficient Attention

A PyTorch implementation defined in C++

Note

This tutorial requires PyTorch 2.0.0 or later.

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)

tensor([[[-1.3321, -0.3489,  0.3015, -0.3912,  0.9867,  0.3137, -0.0691, -1.2593],
         [-1.0882,  0.2506,  0.6491,  0.1360,  0.5238, -0.2448, -0.0820, -0.6171],
         [-1.0012,  0.3990,  0.6441, -0.0277,  0.5325, -0.2564, -0.0607, -0.6404]],

        [[ 0.6091,  0.0708,  0.6188,  0.3252, -0.1598,  0.4197, -0.2335,  0.0630],
         [ 0.5285,  0.3890, -0.2649,  0.3706, -0.3839,  0.1963, -0.6242,  0.2312],
         [ 0.4048,  0.0762,  0.3777,  0.4689, -0.2978,  0.2754, -0.6429,  0.1037]]],
       device='cuda:0')

Explicit Dispatcher Control¶

While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch with a context manager. This context manager allows users to explicitly disable certain implementations. To confirm which implementation is fastest for a specific set of inputs, you can use the context manager to enable one implementation at a time and measure the performance of each.

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arguments mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")

The default implementation runs in 2262.411 microseconds
The math implementation runs in 19253.835 microseconds
The flash attention implementation runs in 2262.734 microseconds
The memory efficient implementation runs in 4202.967 microseconds

Hardware dependence¶

Depending on what machine you ran the above cell on and what hardware is available, your results might be different.

- If you don’t have a GPU and are running on CPU, then the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, flash attention or memory efficient might have failed.
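A quick way to see which of these cases applies to your machine is to query the device and the global SDPA backend switches. The sketch below is one possible way to do that; the helper name describe_sdpa_environment is hypothetical. Note that the three flags only report whether a backend is allowed globally, not whether it supports a particular input or GPU.

```python
import torch


def describe_sdpa_environment():
    """Collect facts that influence which SDPA backends can be used."""
    info = {
        "cuda_available": torch.cuda.is_available(),
        # Global switches: True means the backend is allowed to be selected,
        # not that it supports a given input on this hardware.
        "flash_enabled": torch.backends.cuda.flash_sdp_enabled(),
        "mem_efficient_enabled": torch.backends.cuda.mem_efficient_sdp_enabled(),
        "math_enabled": torch.backends.cuda.math_sdp_enabled(),
    }
    if info["cuda_available"]:
        major, minor = torch.cuda.get_device_capability()
        info["device_name"] = torch.cuda.get_device_name()
        info["compute_capability"] = f"{major}.{minor}"
    return info


for name, value in describe_sdpa_environment().items():
    print(f"{name}: {value}")
```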

Causal Self Attention¶

Below is an example implementation of a multi-headed causal self attention block inspired by Andrej Karpathy’s NanoGPT repository.

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        # Note: as written, dropout and causal masking are applied only in training mode
        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
# Use the ``device`` selected earlier instead of hard-coding "cuda"
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to(device).to(dtype).eval()
print(model)

CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

NestedTensor and Dense tensor support¶

SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors, see the torch.nested documentation and the NestedTensors tutorial.
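To illustrate what padding costs, the short sketch below packs three sequences of different lengths into a NestedTensor and then converts it to a padded dense tensor for comparison. The sequence lengths and embedding size are arbitrary illustrative values, not taken from the benchmark in this tutorial.

```python
import torch

torch.manual_seed(0)

# Illustrative values only (assumptions, not from the tutorial).
embed_dim = 4
seq_lens = [2, 5, 3]
sequences = [torch.randn(n, embed_dim) for n in seq_lens]

# A NestedTensor stores each sequence at its own length.
nt = torch.nested.nested_tensor(sequences)

# The dense equivalent pads every sequence to the longest one (5 here).
padded = torch.nested.to_padded_tensor(nt, 0.0)

real_tokens = sum(seq_lens)
padded_tokens = padded.size(0) * padded.size(1)

print(nt.is_nested)   # True
print(padded.shape)   # torch.Size([3, 5, 4])
print(f"{real_tokens} real tokens vs {padded_tokens} padded slots")  # 10 real tokens vs 15 padded slots
```

In this toy batch a third of the padded slots hold no data, which is work a dense kernel still has to perform.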

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nested/__init__.py:166: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:177.)

Random NT runs in 561.060 microseconds
Random Dense runs in 937.139 microseconds

Using SDPA with torch.compile¶

With the release of PyTorch 2.0, a new feature called torch.compile() has been introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let’s compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")

compiled_model = torch.compile(model)
# Let's compile it
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")

The non compiled module runs in 407.485 microseconds
The compiled module runs in 523.554 microseconds

The exact execution time depends on the machine; the results on mine were:

The non compiled module runs in 166.616 microseconds
The compiled module runs in 166.726 microseconds

That is not what we were expecting. Let’s dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use chrome://tracing to view the results:
# prof.export_chrome_trace("compiled_causal_attention_trace.json")

------------------------------------------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
Name                                                 Self CPU %       Self CPU    CPU total %      CPU total   CPU time avg      Self CUDA    Self CUDA %     CUDA total  CUDA time avg     # of Calls
------------------------------------------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
Non-Compilied Causal Attention                           19.62%        2.234ms         76.88%        8.755ms        8.755ms        0.000us          0.00%       11.053ms       11.053ms              1
aten::linear                                              1.25%      142.000us         30.28%        3.448ms       68.960us        0.000us          0.00%        8.166ms      163.320us             50
aten::matmul                                              2.34%      267.000us         26.36%        3.002ms       60.040us        0.000us          0.00%        8.166ms      163.320us             50
aten::mm                                                 17.96%        2.045ms         22.00%        2.505ms       50.100us        7.760ms         76.48%        8.166ms      163.320us             50
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn            0.00%        0.000us          0.00%        0.000us        0.000us        5.559ms         54.78%        5.559ms      222.360us             25
aten::scaled_dot_product_attention                        1.88%      214.000us         18.21%        2.074ms       82.960us        0.000us          0.00%        2.887ms      115.480us             25
aten::_scaled_dot_product_flash_attention                 3.49%      397.000us         16.33%        1.860ms       74.400us        0.000us          0.00%        2.887ms      115.480us             25
aten::_flash_attention_forward                            4.40%      501.000us         11.85%        1.350ms       54.000us        2.387ms         23.52%        2.887ms      115.480us             25
void pytorch_flash::flash_fwd_kernel

